library(RItools)
library(Matching)
library(coin)
library(exactRankTests)

balance.check <- function(dat, vars.to.eval, vars.names = NULL, vars.type, tr.var, tr.type = "binary", tr.names = NULL,
                          graph.qq = TRUE, plot.file.qq = NULL, graph.pvals = TRUE, plot.file.pvals = NULL,
                          text.xvalue = -.1, cex.pval.axis = .75, rnd = 3) {
  # Check covariate balance across treatment groups.
  #
  # tr.type = "binary": each covariate in vars.to.eval is compared between
  # the two (sorted) treatment levels with
  #   - ks.p: bootstrap Kolmogorov-Smirnov test (Matching::ks.boot),
  #   - tt.p: two-sample t-test (stats::t.test),
  #   - wt.p: exact rank-based independence test (coin::independence_test
  #           with ytrafo = rank, distribution = exact()).
  # An omnibus d^2 test over all covariates (RItools::xBalance,
  # chisquare.test) is also run. Optionally, QQ plots ("cont" covariates) or
  # mosaic plots ("disc" covariates) are drawn, and so is a summary plot of
  # all p-values.
  #
  # tr.type = "continuous": each covariate is regressed on the treatment.
  # "cont" covariates use lm; "disc" covariates with exactly two levels use
  # a binomial glm. The p-value of the treatment coefficient is reported.
  # "disc" covariates with any other number of levels get NA. No omnibus
  # test is run and no plots are drawn.
  #
  # Arguments:
  #   dat             data.frame containing the covariates and the treatment.
  #   vars.to.eval    character vector of covariate column names.
  #   vars.names      display labels for the covariates (default: vars.to.eval).
  #   vars.type       "cont" or "disc" for each covariate, parallel to vars.to.eval.
  #   tr.var          name of the treatment column.
  #   tr.type         "binary" or "continuous".
  #   tr.names        QQ-plot axis labels for the two treatment levels
  #                   (default: the sorted treatment levels).
  #   graph.qq        draw QQ/mosaic plots? (binary only)
  #   plot.file.qq    if non-NULL, the QQ/mosaic plots are written to this PDF file.
  #   graph.pvals     draw the p-value summary plot? (binary only)
  #   plot.file.pvals if non-NULL, the p-value plot is written to this PDF file.
  #   text.xvalue     unused; kept for backward compatibility.
  #   cex.pval.axis   cex.axis for the covariate labels on the p-value plot.
  #   rnd             number of decimal places the returned p-values are rounded to.
  #
  # Returns a list with:
  #   pvals  binary: data.frame with columns ks.p, tt.p, wt.p (one row per
  #          covariate); continuous: array with one p-value per covariate.
  #   d2.p   omnibus d^2 p-value (binary) or NA (continuous).
  if (length(tr.type) != 1L || !(tr.type %in% c("binary", "continuous"))) {
    stop("'tr.type' must be \"binary\" or \"continuous\"", call. = FALSE)
  }
  n.vars <- length(vars.to.eval)

  if (tr.type == "binary") {
    if (is.null(vars.names)) {
      vars.names <- vars.to.eval
    }

    # The treatment split is the same for every covariate, so compute it once.
    tr.levels <- sort(unique(dat[, tr.var]))
    if (length(tr.levels) != 2L) {
      stop("binary treatment '", tr.var, "' must have exactly two levels, found ",
           length(tr.levels), call. = FALSE)
    }
    if (is.null(tr.names)) {
      tr.names <- tr.levels
    }
    # which() drops rows with a missing treatment instead of yielding NA rows.
    rows.1 <- which(dat[, tr.var] == tr.levels[1])
    rows.2 <- which(dat[, tr.var] == tr.levels[2])

    pvals <- data.frame(matrix(NA, n.vars, 3))
    colnames(pvals) <- c("ks.p", "tt.p", "wt.p")

    draw.qq <- isTRUE(graph.qq)
    if (!is.null(plot.file.qq)) {
      pdf(file = plot.file.qq)
    }
    if (draw.qq) {
      plot.cols <- ceiling(sqrt(n.vars))
      plot.rows <- ceiling(n.vars / plot.cols)
      old.par <- par(mfrow = c(plot.rows, plot.cols))
    }

    for (i in seq_len(n.vars)) {
      current.var <- vars.to.eval[i]
      current.name <- vars.names[i]
      x.vec <- dat[rows.1, current.var]
      y.vec <- dat[rows.2, current.var]

      # KS (bootstrap)
      pvals$ks.p[i] <- ks.boot(x.vec, y.vec)$ks.boot.pvalue
      # t-test
      pvals$tt.p[i] <- t.test(x.vec, y.vec)$p.value
      # Exact rank-based (Wilcoxon-type) test via coin
      fff <- as.formula(paste0(current.var, " ~ as.factor(", tr.var, ")"))
      pvals$wt.p[i] <- pvalue(independence_test(fff, data = dat, ytrafo = rank, distribution = exact()))

      if (draw.qq) {
        if (vars.type[i] == "cont") {
          var.range <- range(dat[, current.var], na.rm = TRUE)
          qqplot(x.vec, y.vec, xlim = var.range, ylim = var.range,
                 xlab = tr.names[1], ylab = tr.names[2], main = current.name)
          # abline is a low-level call and draws on the current panel.
          # Do not set par(new = TRUE) here: it would make the next panel
          # overplot this one.
          abline(0, 1, col = "red")
        }
        if (vars.type[i] == "disc") {
          mosaicplot(~ dat[, current.var] + dat[, tr.var], xlab = current.name,
                     ylab = "Treatment", main = "", col = TRUE)
        }
      }
    }
    if (draw.qq) {
      # Restore the layout before closing any QQ device, so that later plots
      # (e.g. the p-value plot on the same device) are not put in a grid panel.
      par(old.par)
    }
    if (!is.null(plot.file.qq)) {
      dev.off()
    }

    # Omnibus d^2 balance test over all covariates at once.
    ffff <- as.formula(paste0("I(as.numeric(as.factor(", tr.var, "))) ~ ",
                              paste(vars.to.eval, collapse = " + ")))
    d2.p <- xBalance(ffff, data = dat, report = c("chisquare.test"))$overall$p.value

    if (isTRUE(graph.pvals)) {
      if (!is.null(plot.file.pvals)) {
        pdf(file = plot.file.pvals)
      }
      # Vertical positions, top to bottom: d^2 at (n + 1)/(n + 2), then
      # covariate i at (n - i + 1)/(n + 2).
      row.y <- (n.vars + 1 - 0:n.vars) / (n.vars + 2)
      plot(-1, -1, xlim = c(0, 1), ylim = c(0, 1), xlab = "p-values", ylab = "", axes = FALSE)
      axis(1)
      axis(2, at = row.y, labels = c(expression(italic(d)^2), vars.names), tick = FALSE,
           cex.axis = cex.pval.axis, las = 1, mgp = c(3, -.8, 0))
      abline(v = c(.05, .1), lty = 2, col = "grey")
      points(d2.p, y = row.y[1], pch = 19, col = "red")
      for (i in seq_len(n.vars)) {
        # pch 2:4 mark ks.p, tt.p and wt.p respectively.
        points(unlist(pvals[i, ]), y = rep(row.y[i + 1], 3), pch = 2:4)
      }
      if (!is.null(plot.file.pvals)) {
        dev.off()
      }
    }

    storage <- list(pvals = round(pvals, rnd), d2.p = round(d2.p, rnd))
  } else {
    pvals <- array(NA, n.vars)
    for (i in seq_len(n.vars)) {
      cov.vals <- dat[, vars.to.eval[i]]
      if (vars.type[i] == "cont") {
        pvals[i] <- summary(lm(cov.vals ~ dat[, tr.var]))$coefficients[2, 4]
      }
      if (vars.type[i] == "disc" && nlevels(as.factor(cov.vals)) == 2L) {
        pvals[i] <- summary(glm(cov.vals ~ dat[, tr.var], family = "binomial"))$coefficients[2, 4]
      }
    }
    # Round consistently with the binary branch; NA entries stay NA.
    storage <- list(pvals = round(pvals, rnd), d2.p = NA)
  }
  storage
}
